/* =========================================================================
 * MPU6050 ROCKET STABILIZATION - ACCESS POINT (AP) VERSION
 * =========================================================================
 * CHANGES FROM PREVIOUS:
 * 1. WiFi Mode: Creates its own WiFi network (SoftAP)
 * 2. Addressing: Defaults to BROADCAST IP until it hears from a specific computer
 * 3. Discovery: Auto-detects the computer's IP when it receives a command
 * 4. PID Tuning: Added runtime PID tuning and saving to internal Flash (NVS)
 * ========================================================================= */

#include <Arduino.h>
#include <Wire.h>
#include <Adafruit_MPU6050.h>
#include <Adafruit_Sensor.h>
#include <ESP32Servo.h>
#include <WiFi.h>
#include <WiFiUdp.h>
#include <Preferences.h>

// --- WiFi / Network Configuration ---
// The Rocket will CREATE this network
const char* ssid = "RocketLink_Telemetry";
const char* password = "rocketlaunch"; 

// Telemetry Destination
// Initially set to BROADCAST so any connected computer can hear it immediately.
// Once a command is received, this updates to the specific sender's IP.
IPAddress remoteIp(192, 168, 4, 255); 
const int remotePort = 4444; // Port to send telemetry TO
const int localPort = 4444;  // Port to listen for commands ON

WiFiUDP udp;
Preferences preferences;

// --- Hardware Pin Definitions ---
const int LEFT_SERVO_PIN = 26;
const int RIGHT_SERVO_PIN = 25;
const int UP_SERVO_PIN = 27;
const int DOWN_SERVO_PIN = 14;

// --- Servo Settings ---
Servo leftServo;
Servo rightServo;
Servo upServo;
Servo downServo;

const int LEFT_CENTER = 115;
const int RIGHT_CENTER = 80;
const int UP_CENTER = 80;
const int DOWN_CENTER = 115;
const int MAX_DEFLECTION = 15; 

// --- PID & Stabilization ---
Adafruit_MPU6050 mpu;
float Kp = 0.5;   // Default, overridden by preferences in setup
float Kd = 0.2;   // Default, overridden by preferences in setup
float roll = 0;
float gyroX_offset = 0; 

// --- State Machine ---
enum RocketState {
  STATE_IDLE,      
  STATE_LAUNCHING, 
  STATE_FLIGHT     
};

RocketState currentState = STATE_IDLE;

// --- Launch Detection Constants ---
const float LAUNCH_THRESHOLD_G = 1.6;
const float G_TO_MS2 = 9.80665;
const float LAUNCH_THRESHOLD_MS2 = LAUNCH_THRESHOLD_G * G_TO_MS2; 
const int LAUNCH_DURATION_MS = 50;
unsigned long launchTriggerStartTime = 0;

// --- Timing Variables ---
unsigned long last_time;
float dt;
unsigned long lastStateMsgTime = 0;
const int STATE_MSG_INTERVAL = 2000; 

// --- Function Prototypes ---
void setServosEnabled(bool enabled);
void changeState(RocketState newState);
void sendTelemetry(float roll, float rate, int out);
void sendStateUpdate();
void calibrateGyro();

void setup(void) {
  Serial.begin(115200);
  Wire.begin(21, 22);

  // 1. Setup Servos
  ESP32PWM::allocateTimer(0);
  ESP32PWM::allocateTimer(1);
  ESP32PWM::allocateTimer(2);
  ESP32PWM::allocateTimer(3);
  
  leftServo.setPeriodHertz(50);
  rightServo.setPeriodHertz(50);
  upServo.setPeriodHertz(50);
  downServo.setPeriodHertz(50);
  setServosEnabled(false);

  // 2. Setup MPU6050
  if (!mpu.begin()) {
    Serial.println("Failed to find MPU6050 chip");
    while (1) { delay(10); }
  }
  Serial.println("MPU6050 Found!");
  mpu.setAccelerometerRange(MPU6050_RANGE_16_G); 
  mpu.setGyroRange(MPU6050_RANGE_500_DEG);
  mpu.setFilterBandwidth(MPU6050_BAND_21_HZ);

  // 3. Load Calibration and PID config
  preferences.begin("rocket_cfg", false); 
  gyroX_offset = preferences.getFloat("gyro_bias", 0.0);
  Kp = preferences.getFloat("Kp", 0.5); // Load saved Kp, default to 0.5
  Kd = preferences.getFloat("Kd", 0.2); // Load saved Kd, default to 0.2

  Serial.println("--- Saved Config Loaded ---");
  Serial.print("Gyro Bias: "); Serial.println(gyroX_offset);
  Serial.print("Kp: "); Serial.println(Kp);
  Serial.print("Kd: "); Serial.println(Kd);
  Serial.println("---------------------------");

  // 4. Setup Access Point (AP)
  Serial.println("Starting Access Point...");
  WiFi.softAP(ssid, password);
  
  // The default IP for ESP32 SoftAP is usually 192.168.4.1
  Serial.print("AP Created. Connect to: "); Serial.println(ssid);
  Serial.print("Rocket IP: "); Serial.println(WiFi.softAPIP());
  
  udp.begin(localPort);

  Serial.println("System Ready. Broadcasting Telemetry...");
  last_time = millis();
}

void loop() {
  sensors_event_t a, g, temp;
  mpu.getEvent(&a, &g, &temp);

  // Calculate Delta Time
  unsigned long current_time = millis();
  dt = (current_time - last_time) / 1000.0;
  last_time = current_time;
  
  if (dt <= 0) return; 

  // --- Network: Receive Commands & Auto-Detect Remote IP ---
  int packetSize = udp.parsePacket();
  if (packetSize) {
    // CAPTURE SENDER IP: This enables 2-way comms without hardcoding
    IPAddress senderIp = udp.remoteIP();
    if (remoteIp != senderIp) {
        remoteIp = senderIp;
        Serial.print("Connected to Ground Station at: ");
        Serial.println(remoteIp);
    }

    char packetBuffer[255];
    int len = udp.read(packetBuffer, 255);
    if (len > 0) packetBuffer[len] = 0;
    
    String command = String(packetBuffer);
    command.trim();

    // -- Parse Tuning Commands (Any state) --
    if (command.startsWith("PID,")) {
        // Expected format: PID,0.5,0.2
        int firstComma = command.indexOf(',');
        int secondComma = command.indexOf(',', firstComma + 1);
        
        if (firstComma > 0 && secondComma > 0) {
            Kp = command.substring(firstComma + 1, secondComma).toFloat();
            Kd = command.substring(secondComma + 1).toFloat();
            
            // Save to internal flash so it survives reboots
            preferences.putFloat("Kp", Kp);
            preferences.putFloat("Kd", Kd);
            
            Serial.print("New PID Saved -> Kp: "); Serial.print(Kp);
            Serial.print(" Kd: "); Serial.println(Kd);
            
            // Send updated state immediately so dashboard syncs
            sendStateUpdate();
        }
    }
    // -- Parse Mission Commands (Only in IDLE) --
    else if (currentState == STATE_IDLE) {
      if (command.equalsIgnoreCase("launch")) {
         changeState(STATE_LAUNCHING);
      }
      else if (command.equalsIgnoreCase("calibrate")) {
         Serial.println("Recalibration Command Received.");
         calibrateGyro();
         roll = 0; 
      }
    }
  }

  // --- Network: Send State Heartbeat ---
  if (millis() - lastStateMsgTime > STATE_MSG_INTERVAL) {
    sendStateUpdate();
    lastStateMsgTime = millis();
  }

  // --- Gyro Integration ---
  float raw_rate_rad = g.gyro.x - gyroX_offset;
  float rate_deg_s = raw_rate_rad * 180.0 / PI;
  roll += rate_deg_s * dt;

  int currentServoOffset = 0;

  // --- STATE MACHINE LOGIC ---
  switch (currentState) {
    case STATE_IDLE:
      break;

    case STATE_LAUNCHING:
      {
        float accel_mag = sqrt(sq(a.acceleration.x) + sq(a.acceleration.y) + sq(a.acceleration.z));
        if (accel_mag > LAUNCH_THRESHOLD_MS2) {
          if (launchTriggerStartTime == 0) {
            launchTriggerStartTime = millis(); 
          } else if (millis() - launchTriggerStartTime > LAUNCH_DURATION_MS) {
            changeState(STATE_FLIGHT);
          }
        } else {
          launchTriggerStartTime = 0;
        }
      }
      break;

    case STATE_FLIGHT:
      {
        float P_term = Kp * roll;
        float D_term = Kd * rate_deg_s;
        float output = P_term + D_term;

        currentServoOffset = (int)output;
        currentServoOffset = constrain(currentServoOffset, -MAX_DEFLECTION, MAX_DEFLECTION);

        leftServo.write(LEFT_CENTER + currentServoOffset);
        rightServo.write(RIGHT_CENTER + currentServoOffset);
        upServo.write(UP_CENTER + currentServoOffset);
        downServo.write(DOWN_CENTER + currentServoOffset);
      }
      break;
  }

  // --- TELEMETRY ---
  sendTelemetry(roll, rate_deg_s, currentServoOffset);

  delay(5); 
}

// --- Helper Functions ---

void calibrateGyro() {
  Serial.println("Calibrating Gyro... KEEP STILL");
  delay(1000); 
  float sumX = 0;
  int calibration_samples = 200;
  for (int i = 0; i < calibration_samples; i++) {
    sensors_event_t a, g, temp;
    mpu.getEvent(&a, &g, &temp);
    sumX += g.gyro.x;
    delay(5);
  }
  gyroX_offset = sumX / calibration_samples;
  preferences.putFloat("gyro_bias", gyroX_offset);
  Serial.print("New Gyro Offset Saved: "); Serial.println(gyroX_offset);
}

void changeState(RocketState newState) {
  currentState = newState;
  switch (currentState) {
    case STATE_IDLE:
      setServosEnabled(false);
      break;
    case STATE_LAUNCHING:
      setServosEnabled(false); 
      launchTriggerStartTime = 0;
      roll = 0; 
      break;
    case STATE_FLIGHT:
      setServosEnabled(true); 
      break;
  }
  sendStateUpdate();
  lastStateMsgTime = millis();
}

void setServosEnabled(bool enabled) {
  if (enabled) {
    if (!leftServo.attached()) leftServo.attach(LEFT_SERVO_PIN, 500, 2400);
    if (!rightServo.attached()) rightServo.attach(RIGHT_SERVO_PIN, 500, 2400);
    if (!upServo.attached()) upServo.attach(UP_SERVO_PIN, 500, 2400);
    if (!downServo.attached()) downServo.attach(DOWN_SERVO_PIN, 500, 2400);
  } else {
    if (leftServo.attached()) leftServo.detach();
    if (rightServo.attached()) rightServo.detach();
    if (upServo.attached()) upServo.detach();
    if (downServo.attached()) downServo.detach();
  }
}

void sendStateUpdate() {
    // Send to the current remoteIp (Starts as broadcast .255, then becomes specific)
    // Modified to include Kp and Kd so dashboard stays in sync
    udp.beginPacket(remoteIp, remotePort);
    udp.print("STATUS:");
    switch (currentState) {
      case STATE_IDLE: udp.print("IDLE"); break;
      case STATE_LAUNCHING: udp.print("LAUNCHING"); break;
      case STATE_FLIGHT: udp.print("FLIGHT"); break;
    }
    udp.print(",");
    udp.print(Kp);
    udp.print(",");
    udp.print(Kd);
    udp.endPacket();
}

void sendTelemetry(float roll, float rate, int out) {
    udp.beginPacket(remoteIp, remotePort);
    udp.print("T,");
    udp.print(millis());
    udp.print(",");
    udp.print(roll, 1);
    udp.print(",");
    udp.print(rate, 1);
    udp.print(",");
    udp.print(out);
    udp.endPacket();
}